import torchinfo
from models.net import LeNet  # 确保从正确的模块导入LeNet

if __name__ == "__main__":
    # 实例化网络模型
    model = LeNet()  # 正确的实例化方法
    # 打印网络模型信息
    print(model)

    # 使用torchinfo获取更详细的模型信息
    # 假设输入数据的形状为 (batch_size, channels, height, width)
    # 例如，对于MNIST数据集，输入形状为 (1, 28, 28)
    batch_size = 1
    input_shape = (batch_size, 1, 28, 28)
    summary = torchinfo.summary(model, input_shape)
    print(summary)

